{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPU Extreme gradient boosting trained on timestamp text data-set\n",
    "\n",
    "1. Same emotion dataset from [NLP-dataset](https://github.com/huseinzol05/NLP-Dataset)\n",
    "2. Same splitting 80% training, 20% testing, may vary depends on randomness\n",
    "3. Same regex substitution '[^\\\"\\'A-Za-z0-9 ]+'\n",
    "\n",
    "## Example\n",
    "\n",
    "Based on sorted dictionary position\n",
    "\n",
    "text: 'module into which all the refactored classes', matrix: [167, 143, 12, 3, 4, 90]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import re\n",
    "import time\n",
    "import xgboost as xgb\n",
    "import pickle\n",
    "from sklearn.cross_validation import train_test_split\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def clearstring(string):\n",
    "    string = re.sub('[^\\\"\\'A-Za-z0-9 ]+', '', string)\n",
    "    string = string.split(' ')\n",
    "    string = filter(None, string)\n",
    "    string = [y.strip() for y in string]\n",
    "    string = ' '.join(string)\n",
    "    return string\n",
    "\n",
    "# because of sklean.datasets read a document as a single element\n",
    "# so we want to split based on new line\n",
    "def separate_dataset(trainset):\n",
    "    datastring = []\n",
    "    datatarget = []\n",
    "    for i in range(len(trainset.data)):\n",
    "        data_ = trainset.data[i].split('\\n')\n",
    "        # python3, if python2, just remove list()\n",
    "        data_ = list(filter(None, data_))\n",
    "        for n in range(len(data_)):\n",
    "            data_[n] = clearstring(data_[n])\n",
    "        datastring += data_\n",
    "        for n in range(len(data_)):\n",
    "            datatarget.append(trainset.target[i])\n",
    "    return datastring, datatarget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trainset_data = sklearn.datasets.load_files(container_path = 'data', encoding = 'UTF-8')\n",
    "trainset_data.data, trainset_data.target = separate_dataset(trainset_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('dictionary_emotion.p', 'rb') as fopen:\n",
    "    dict_emotion = pickle.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "len_sentences = np.array([len(i.split()) for i in trainset_data.data])\n",
    "maxlen = np.ceil(len_sentences.mean()).astype('int')\n",
    "data_X = np.zeros((len(trainset_data.data), maxlen))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for i in range(data_X.shape[0]):\n",
    "    tokens = trainset_data.data[i].split()[:maxlen]\n",
    "    for no, text in enumerate(tokens[::-1]):\n",
    "        try:\n",
    "            data_X[i, -1 - no] = dict_emotion[text]\n",
    "        except:\n",
    "            continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_X, test_X, train_Y, test_Y = train_test_split(data_X, trainset_data.target, test_size = 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "params_xgd = {\n",
    "    'min_child_weight': 10.0,\n",
    "    'objective': 'multi:softprob',\n",
    "    'eval_metric': 'mlogloss',\n",
    "    'num_class': len(trainset_data.target_names),\n",
    "    'max_depth': 7,\n",
    "    'max_delta_step': 1.8,\n",
    "    'colsample_bytree': 0.4,\n",
    "    'subsample': 0.8,\n",
    "    'eta': 0.03,\n",
    "    'gamma': 0.65,\n",
    "    'num_boost_round' : 700,\n",
    "    'gpu_id': 0,\n",
    "    'tree_method': 'gpu_hist'\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\ttrain-mlogloss:1.78342\tvalid-mlogloss:1.78357\n",
      "Multiple eval metrics have been passed: 'valid-mlogloss' will be used for early stopping.\n",
      "\n",
      "Will train until valid-mlogloss hasn't improved in 200 rounds.\n",
      "[100]\ttrain-mlogloss:1.53612\tvalid-mlogloss:1.5488\n",
      "[200]\ttrain-mlogloss:1.49598\tvalid-mlogloss:1.51999\n",
      "[300]\ttrain-mlogloss:1.4688\tvalid-mlogloss:1.50397\n",
      "[400]\ttrain-mlogloss:1.4456\tvalid-mlogloss:1.49123\n",
      "[500]\ttrain-mlogloss:1.42472\tvalid-mlogloss:1.48043\n",
      "[600]\ttrain-mlogloss:1.40576\tvalid-mlogloss:1.47118\n",
      "[700]\ttrain-mlogloss:1.38838\tvalid-mlogloss:1.46322\n",
      "[800]\ttrain-mlogloss:1.37113\tvalid-mlogloss:1.45556\n",
      "[900]\ttrain-mlogloss:1.35507\tvalid-mlogloss:1.44848\n",
      "[1000]\ttrain-mlogloss:1.34032\tvalid-mlogloss:1.44257\n",
      "[1100]\ttrain-mlogloss:1.32587\tvalid-mlogloss:1.4369\n",
      "[1200]\ttrain-mlogloss:1.31201\tvalid-mlogloss:1.43168\n",
      "[1300]\ttrain-mlogloss:1.29876\tvalid-mlogloss:1.42681\n",
      "[1400]\ttrain-mlogloss:1.28579\tvalid-mlogloss:1.42213\n",
      "[1500]\ttrain-mlogloss:1.27327\tvalid-mlogloss:1.41799\n",
      "[1600]\ttrain-mlogloss:1.26141\tvalid-mlogloss:1.41412\n",
      "[1700]\ttrain-mlogloss:1.24962\tvalid-mlogloss:1.4104\n",
      "[1800]\ttrain-mlogloss:1.23825\tvalid-mlogloss:1.40704\n",
      "[1900]\ttrain-mlogloss:1.22706\tvalid-mlogloss:1.40358\n",
      "[2000]\ttrain-mlogloss:1.21617\tvalid-mlogloss:1.40044\n",
      "[2100]\ttrain-mlogloss:1.20566\tvalid-mlogloss:1.39749\n",
      "[2200]\ttrain-mlogloss:1.19544\tvalid-mlogloss:1.39469\n",
      "[2300]\ttrain-mlogloss:1.18525\tvalid-mlogloss:1.3921\n",
      "[2400]\ttrain-mlogloss:1.17524\tvalid-mlogloss:1.38955\n",
      "[2500]\ttrain-mlogloss:1.16541\tvalid-mlogloss:1.38699\n",
      "[2600]\ttrain-mlogloss:1.15606\tvalid-mlogloss:1.38482\n",
      "[2700]\ttrain-mlogloss:1.14673\tvalid-mlogloss:1.38252\n",
      "[2800]\ttrain-mlogloss:1.13774\tvalid-mlogloss:1.38056\n",
      "[2900]\ttrain-mlogloss:1.12899\tvalid-mlogloss:1.3786\n",
      "[3000]\ttrain-mlogloss:1.12022\tvalid-mlogloss:1.37677\n",
      "[3100]\ttrain-mlogloss:1.11163\tvalid-mlogloss:1.37499\n",
      "[3200]\ttrain-mlogloss:1.10306\tvalid-mlogloss:1.3732\n",
      "[3300]\ttrain-mlogloss:1.09481\tvalid-mlogloss:1.37171\n",
      "[3400]\ttrain-mlogloss:1.0866\tvalid-mlogloss:1.37014\n",
      "[3500]\ttrain-mlogloss:1.0785\tvalid-mlogloss:1.36869\n",
      "[3600]\ttrain-mlogloss:1.07058\tvalid-mlogloss:1.36723\n",
      "[3700]\ttrain-mlogloss:1.0628\tvalid-mlogloss:1.36593\n",
      "[3800]\ttrain-mlogloss:1.0551\tvalid-mlogloss:1.36468\n",
      "[3900]\ttrain-mlogloss:1.04756\tvalid-mlogloss:1.36343\n",
      "[4000]\ttrain-mlogloss:1.03996\tvalid-mlogloss:1.36214\n",
      "[4100]\ttrain-mlogloss:1.03258\tvalid-mlogloss:1.36091\n",
      "[4200]\ttrain-mlogloss:1.0252\tvalid-mlogloss:1.35981\n",
      "[4300]\ttrain-mlogloss:1.01807\tvalid-mlogloss:1.35885\n",
      "[4400]\ttrain-mlogloss:1.01114\tvalid-mlogloss:1.35787\n",
      "[4500]\ttrain-mlogloss:1.00419\tvalid-mlogloss:1.35693\n",
      "[4600]\ttrain-mlogloss:0.997363\tvalid-mlogloss:1.35613\n",
      "[4700]\ttrain-mlogloss:0.990603\tvalid-mlogloss:1.35519\n",
      "[4800]\ttrain-mlogloss:0.98389\tvalid-mlogloss:1.35446\n",
      "[4900]\ttrain-mlogloss:0.977365\tvalid-mlogloss:1.35373\n",
      "[5000]\ttrain-mlogloss:0.970875\tvalid-mlogloss:1.35296\n",
      "[5100]\ttrain-mlogloss:0.964538\tvalid-mlogloss:1.35221\n",
      "[5200]\ttrain-mlogloss:0.958267\tvalid-mlogloss:1.35157\n",
      "[5300]\ttrain-mlogloss:0.952025\tvalid-mlogloss:1.35101\n",
      "[5400]\ttrain-mlogloss:0.945794\tvalid-mlogloss:1.35031\n",
      "[5500]\ttrain-mlogloss:0.939654\tvalid-mlogloss:1.34974\n",
      "[5600]\ttrain-mlogloss:0.933496\tvalid-mlogloss:1.34926\n",
      "[5700]\ttrain-mlogloss:0.92751\tvalid-mlogloss:1.34867\n",
      "[5800]\ttrain-mlogloss:0.921477\tvalid-mlogloss:1.34815\n",
      "[5900]\ttrain-mlogloss:0.915558\tvalid-mlogloss:1.3478\n",
      "[6000]\ttrain-mlogloss:0.90989\tvalid-mlogloss:1.34738\n",
      "[6100]\ttrain-mlogloss:0.904207\tvalid-mlogloss:1.34696\n",
      "[6200]\ttrain-mlogloss:0.89875\tvalid-mlogloss:1.34668\n",
      "[6300]\ttrain-mlogloss:0.893171\tvalid-mlogloss:1.34636\n",
      "[6400]\ttrain-mlogloss:0.887585\tvalid-mlogloss:1.34594\n",
      "[6500]\ttrain-mlogloss:0.882105\tvalid-mlogloss:1.34557\n",
      "[6600]\ttrain-mlogloss:0.876626\tvalid-mlogloss:1.34521\n",
      "[6700]\ttrain-mlogloss:0.871228\tvalid-mlogloss:1.34488\n",
      "[6800]\ttrain-mlogloss:0.865906\tvalid-mlogloss:1.34464\n",
      "[6900]\ttrain-mlogloss:0.860721\tvalid-mlogloss:1.34432\n",
      "[7000]\ttrain-mlogloss:0.855532\tvalid-mlogloss:1.34408\n",
      "[7100]\ttrain-mlogloss:0.850394\tvalid-mlogloss:1.34399\n",
      "[7200]\ttrain-mlogloss:0.84524\tvalid-mlogloss:1.34375\n",
      "[7300]\ttrain-mlogloss:0.840221\tvalid-mlogloss:1.3436\n",
      "[7400]\ttrain-mlogloss:0.835169\tvalid-mlogloss:1.34331\n",
      "[7500]\ttrain-mlogloss:0.830177\tvalid-mlogloss:1.34316\n",
      "[7600]\ttrain-mlogloss:0.825337\tvalid-mlogloss:1.34298\n",
      "[7700]\ttrain-mlogloss:0.820564\tvalid-mlogloss:1.34287\n",
      "[7800]\ttrain-mlogloss:0.815755\tvalid-mlogloss:1.34275\n",
      "[7900]\ttrain-mlogloss:0.811022\tvalid-mlogloss:1.34266\n",
      "[8000]\ttrain-mlogloss:0.806267\tvalid-mlogloss:1.34254\n",
      "[8100]\ttrain-mlogloss:0.801612\tvalid-mlogloss:1.34252\n",
      "[8200]\ttrain-mlogloss:0.797026\tvalid-mlogloss:1.34253\n",
      "[8300]\ttrain-mlogloss:0.792455\tvalid-mlogloss:1.34246\n",
      "[8400]\ttrain-mlogloss:0.787938\tvalid-mlogloss:1.34245\n",
      "[8500]\ttrain-mlogloss:0.783446\tvalid-mlogloss:1.34248\n",
      "[8600]\ttrain-mlogloss:0.779003\tvalid-mlogloss:1.34253\n",
      "Stopping. Best iteration:\n",
      "[8429]\ttrain-mlogloss:0.786625\tvalid-mlogloss:1.34242\n",
      "\n",
      "1010.187 Seconds to train xgb\n"
     ]
    }
   ],
   "source": [
    "d_train = xgb.DMatrix(train_X, train_Y)\n",
    "d_valid = xgb.DMatrix(test_X, test_Y)\n",
    "watchlist = [(d_train, 'train'), (d_valid, 'valid')]\n",
    "#with open('clf.p', 'rb') as fopen:\n",
    "#    clf = pickle.load(fopen)\n",
    "t=time.time()\n",
    "clf = xgb.train(params_xgd, d_train, 100000, watchlist, early_stopping_rounds=200, maximize=False, verbose_eval=100)\n",
    "print(round(time.time()-t, 3), 'Seconds to train xgb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.47465271946450421"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(test_Y == np.argmax(clf.predict(xgb.DMatrix(test_X), ntree_limit=clf.best_ntree_limit), axis = 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "clf.save_model('0001.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "bst = xgb.Booster(params_xgd)\n",
    "bst.load_model('0001.model')\n",
    "with open('xgb-timestamp-param', 'w') as fopen:\n",
    "    fopen.write(json.dumps(params_xgd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.47461673184424558"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(test_Y == np.argmax(bst.predict(xgb.DMatrix(test_X)), axis = 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "      anger       0.48      0.22      0.31     11390\n",
      "       fear       0.46      0.20      0.28      9759\n",
      "        joy       0.48      0.72      0.58     27981\n",
      "       love       0.26      0.08      0.12      6838\n",
      "    sadness       0.49      0.58      0.53     24395\n",
      "   surprise       0.21      0.07      0.11      2999\n",
      "\n",
      "avg / total       0.45      0.47      0.44     83362\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "print(metrics.classification_report(test_Y, np.argmax(bst.predict(xgb.DMatrix(test_X)), axis = 1), target_names = trainset_data.target_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
